{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G3MMAcssHTML"
      },
      "source": [
        "<link rel=\"stylesheet\" href=\"/site-assets/css/gemma.css\">\n",
        "<link rel=\"stylesheet\" href=\"https://fonts.googleapis.com/css2?family=Google+Symbols:opsz,wght,FILL,GRAD@20..48,100..700,0..1,-50..200\" />"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2025 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "60KmTK7o6ppd"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://ai.google.dev/gemma/docs/core/distributed_tuning\"><img src=\"https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png\" height=\"32\" width=\"32\" />View on ai.google.dev</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/googlecolab/colabtools/blob/main/notebooks/Gemma_Distributed_Fine_tuning_on_TPU.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://kaggle.com/kernels/welcome?src=https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/core/distributed_tuning.ipynb\"><img src=\"https://www.kaggle.com/static/images/logos/kaggle-logo-transparent-300.png\" height=\"32\" width=\"70\"/>Run in Kaggle</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/google/generative-ai-docs/main/site/en/gemma/docs/core/distributed_tuning.ipynb\"><img src=\"https://ai.google.dev/images/cloud-icon.svg\" width=\"40\" />Open in Vertex AI</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/core/distributed_tuning.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FUOiKRSF7jc1"
      },
      "source": [
        "# Distributed tuning with Gemma using Keras"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tdlq6K0znh3O"
      },
      "source": [
        "Gemma is a family of lightweight, state-of-the-art open models built from research and technology used to create Google Gemini models. Gemma can be further finetuned to suit specific needs. However, large language models such as Gemma can be too large to fit on a single accelerator for finetuning. In that case, there are two general approaches for finetuning them:\n",
        "1. Parameter Efficient Fine-Tuning (PEFT), which seeks to shrink the effective model size by sacrificing some fidelity. LoRA falls in this category and the [Fine-tune Gemma models in Keras using LoRA](https://ai.google.dev/gemma/docs/core/lora_tuning) tutorial demonstrates how to finetune the Gemma 2B model `gemma_2b_en` with LoRA using KerasNLP on a single GPU.\n",
        "2. Full parameter finetuning with model parallelism. Model parallelism distributes a single model's weights across multiple devices and enables horizontal scaling. You can find out more about distributed training in this [Keras guide](https://keras.io/guides/distribution/).\n",
        "\n",
        "This tutorial walks you through using Keras with a JAX backend to finetune the Gemma 7B model with LoRA and model-parallel distributed training on Google's Tensor Processing Unit (TPU). Note that you can turn off LoRA in this tutorial for slower but more accurate full-parameter tuning."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z-jBO5hmDwrc"
      },
      "source": [
        "## Using accelerators\n",
        "\n",
        "Technically you can use either TPU or GPU for this tutorial.\n",
        "\n",
        "### Notes on TPU environments\n",
        "\n",
        "Google has 3 products that provide TPUs:\n",
        "* [Colab](https://colab.sandbox.google.com/) provides TPU v2 for free, which is sufficient for this tutorial.\n",
        "* [Kaggle](https://www.kaggle.com/) offers TPU v3 for free, which also works for this tutorial.\n",
        "* [Cloud TPU](https://cloud.google.com/tpu) offers TPU v3 and newer generations. One way to set it up is:\n",
        "  1. Create a new [TPU VM](https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm#tpu-vms)\n",
        "  2. Set up [SSH port forwarding](https://cloud.google.com/solutions/connecting-securely#port-forwarding-over-ssh) for your intended Jupyter server port\n",
        "  3. Install Jupyter and start it on the TPU VM, then connect to Colab through \"Connect to a local runtime\"\n",
        "\n",
        "### Notes on multi-GPU setup\n",
        "\n",
        "Although this tutorial focuses on the TPU use case, you can easily adapt it for your own needs if you have a multi-GPU machine.\n",
        "\n",
        "If you prefer to work through Colab, it's also possible to provision a multi-GPU VM for Colab directly through \"Connect to a custom GCE VM\" in the Colab Connect menu.\n",
        "\n",
        "\n",
        "We will focus on using the **free TPU from Kaggle** here."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Xry37wK5-Hrx"
      },
      "source": [
        "## Before you begin"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aKvTsIkL98BG"
      },
      "source": [
        "### Kaggle credentials\n",
        "\n",
        "Gemma models are hosted by Kaggle. To use Gemma, request access on Kaggle:\n",
        "\n",
        "- Sign in or register at [kaggle.com](https://www.kaggle.com)\n",
        "- Open the [Gemma model card](https://www.kaggle.com/models/google/gemma) and select _\"Request Access\"_\n",
        "- Complete the consent form and accept the terms and conditions\n",
        "\n",
        "Then, to use the Kaggle API, create an API token:\n",
        "\n",
        "- Open the [Kaggle settings](https://www.kaggle.com/settings)\n",
        "- Select _\"Create New Token\"_\n",
        "- A `kaggle.json` file is downloaded. It contains your Kaggle credentials\n",
        "\n",
        "Run the following cell and enter your Kaggle credentials when asked."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lKoW-nhE-gNO"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "8c9b78140c8943edbd6191da7141650b",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "VBox(children=(HTML(value='<center> <img\\nsrc=https://www.kaggle.com/static/images/site-logo.png\\nalt=\\'Kaggle…"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# If you are using Kaggle, you don't need to log in again.\n",
        "!pip install ipywidgets\n",
        "import kagglehub\n",
        "\n",
        "kagglehub.login()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "V8cYitUVSgP0"
      },
      "source": [
        "Alternatively, if `kagglehub.login()` doesn't work for you, set `KAGGLE_USERNAME` and `KAGGLE_KEY` in your environment."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AO7a1Q4Yyc9Z"
      },
      "source": [
        "## Installation\n",
        "\n",
        "Install Keras and KerasNLP with the Gemma model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WWEzVJR4Fx9g"
      },
      "outputs": [],
      "source": [
        "!pip install -q -U keras-nlp\n",
        "# Work around an import error with tensorflow-hub. The library is not used.\n",
        "!pip install -q -U tensorflow-hub\n",
        "# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.\n",
        "!pip install -q -U tensorflow-cpu tensorflow-text\n",
        "# Install keras 3 last. See https://keras.io/getting_started for details.\n",
        "!pip install -q -U keras"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fr9VnPm7FoMf"
      },
      "source": [
        "### Set up Keras JAX backend"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lbZsUvfhwL2D"
      },
      "source": [
        "Import JAX and run a sanity check on TPU. Kaggle offers TPUv3-8 devices which have 8 TPU cores with 16GB of memory each."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BK4MpHLKGujb"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),\n",
              " TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),\n",
              " TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),\n",
              " TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),\n",
              " TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),\n",
              " TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),\n",
              " TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),\n",
              " TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]"
            ]
          },
          "execution_count": 4,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "import jax\n",
        "\n",
        "jax.devices()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WEgg_OVIL2HY"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "# The Keras 3 distribution API is only implemented for the JAX backend for now\n",
        "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n",
        "# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation\n",
        "# overhead\n",
        "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \"0.9\""
      ]
    },
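    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough, back-of-the-envelope check of why sharding is needed, the sketch below estimates the weight memory of the model. It assumes the parameter count of 8,548,748,288 reported by the model summary later in this notebook, weights stored in `float32` or `bfloat16`, a perfectly even split across the 8 TPU cores, and a usable budget that mirrors the 0.9 memory fraction set above. It ignores activations, optimizer state, and replicated tensors, so treat the numbers as lower bounds rather than measurements."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Back-of-the-envelope weight memory estimate (illustrative only).\n",
        "NUM_PARAMS = 8_548_748_288  # From the model summary later in this notebook.\n",
        "NUM_DEVICES = 8             # TPU cores on a TPU v3-8.\n",
        "DEVICE_MEMORY_GIB = 16      # Memory per core, as stated above.\n",
        "MEM_FRACTION = 0.9          # Mirrors XLA_PYTHON_CLIENT_MEM_FRACTION above.\n",
        "GIB = 2**30\n",
        "\n",
        "def weight_memory_gib(num_params, bytes_per_param):\n",
        "  \"\"\"Returns the memory needed to hold the weights, in GiB.\"\"\"\n",
        "  return num_params * bytes_per_param / GIB\n",
        "\n",
        "budget = DEVICE_MEMORY_GIB * MEM_FRACTION\n",
        "for dtype, nbytes in [('float32', 4), ('bfloat16', 2)]:\n",
        "  total = weight_memory_gib(NUM_PARAMS, nbytes)\n",
        "  per_device = total / NUM_DEVICES\n",
        "  print(f'{dtype:<9} total: {total:6.2f} GiB   '\n",
        "        f'per device if evenly sharded: {per_device:5.2f} GiB   '\n",
        "        f'fits on one device: {total < budget}')"
      ]
    },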
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wo1xkzr62hXN"
      },
      "source": [
        "## Load model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kFCmWEKdMA_Y"
      },
      "outputs": [],
      "source": [
        "import keras\n",
        "import keras_nlp"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bx3m8f1dB7nk"
      },
      "source": [
        "### Notes on mixed precision training on NVIDIA GPUs\n",
        "\n",
        "When training on NVIDIA GPUs, mixed precision (`keras.mixed_precision.set_global_policy('mixed_bfloat16')`) can be used to speed up training with minimal effect on training quality. In most cases, it is recommended to turn on mixed precision as it saves both memory and time. However, be aware that at small batch sizes, it can inflate memory usage by 1.5x (weights will be loaded twice, at half precision and full precision).\n",
        "\n",
        "For inference, half-precision (`keras.config.set_floatx(\"bfloat16\")`) will work and save memory while mixed-precision is not applicable."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "T0lHxEDX03gp"
      },
      "outputs": [],
      "source": [
        "# Uncomment the line below if you want to enable mixed precision training on GPUs\n",
        "# keras.mixed_precision.set_global_policy('mixed_bfloat16')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xrR8TpVS6uPs"
      },
      "source": [
        "To load the model with the weights and tensors distributed across TPUs, first create a new `DeviceMesh`. `DeviceMesh` represents a collection of hardware devices configured for distributed computation and was introduced in Keras 3 as part of the unified distribution API.\n",
        "\n",
        "The distribution API enables data and model parallelism, allowing for efficient scaling of deep learning models on multiple accelerators and hosts. It leverages the underlying framework (e.g. JAX) to distribute the program and tensors according to the sharding directives through a procedure called single program, multiple data (SPMD) expansion. Check out more details in the new [Keras 3 distribution API guide](https://keras.io/guides/distribution/)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7gxEkpUiP1Qf"
      },
      "outputs": [],
      "source": [
        "# Create a device mesh with (1, 8) shape so that the weights are sharded across\n",
        "# all 8 TPUs.\n",
        "device_mesh = keras.distribution.DeviceMesh(\n",
        "    (1, 8),\n",
        "    [\"batch\", \"model\"],\n",
        "    devices=keras.distribution.list_devices())"
      ]
    },
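    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the `(1, 8)` mesh shape concrete, the following sketch computes the per-device shape of a tensor when one of its axes is split along a named mesh axis. It is plain Python rather than a Keras API: `shard_shape` is a hypothetical helper that assumes an even split, and the `(256000, 3072)` shape is an assumed token-embedding shape based on Gemma 7B's 256,000-token vocabulary and 3072-wide hidden dimension."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Illustrative only: `shard_shape` is a hypothetical helper, not a Keras API.\n",
        "MESH_SHAPE = {'batch': 1, 'model': 8}  # Mirrors the (1, 8) DeviceMesh above.\n",
        "\n",
        "def shard_shape(shape, layout, mesh_shape=MESH_SHAPE):\n",
        "  \"\"\"Returns the per-device shape of a tensor with the given layout.\n",
        "\n",
        "  Each layout entry is either None (the axis is replicated) or the name of a\n",
        "  mesh axis (the axis is split evenly across that mesh axis).\n",
        "  \"\"\"\n",
        "  if len(shape) != len(layout):\n",
        "    raise ValueError('shape and layout must have the same rank')\n",
        "  result = []\n",
        "  for dim, axis_name in zip(shape, layout):\n",
        "    parts = 1 if axis_name is None else mesh_shape[axis_name]\n",
        "    if dim % parts != 0:\n",
        "      raise ValueError(f'dimension {dim} is not divisible by {parts}')\n",
        "    result.append(dim // parts)\n",
        "  return tuple(result)\n",
        "\n",
        "# Assumed embedding shape: (vocabulary size, hidden dimension).\n",
        "print(shard_shape((256000, 3072), ('model', None)))  # Rows split 8 ways.\n",
        "print(shard_shape((256000, 3072), (None, None)))     # Fully replicated."
      ]
    },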
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gTSJUwkC-7c6"
      },
      "source": [
        "`LayoutMap` from the distribution API specifies how weights and tensors should be sharded or replicated. Its string keys, for example `token_embedding/embeddings` below, are treated as regular expressions and matched against tensor paths. Matched tensors are sharded along the `model` dimension (8 TPUs); all others are fully replicated."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8Wgh8h0qQCcu"
      },
      "outputs": [],
      "source": [
        "model_dim = \"model\"\n",
        "\n",
        "layout_map = keras.distribution.LayoutMap(device_mesh)\n",
        "\n",
        "# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs\n",
        "layout_map[\"token_embedding/embeddings\"] = (model_dim, None)\n",
        "# Regex to match against the query, key and value matrices in the decoder\n",
        "# attention layers\n",
        "layout_map[\"decoder_block.*attention.*(query|key|value).*kernel\"] = (\n",
        "    model_dim, None, None)\n",
        "\n",
        "layout_map[\"decoder_block.*attention_output.*kernel\"] = (\n",
        "    model_dim, None, None)\n",
        "layout_map[\"decoder_block.*ffw_gating.*kernel\"] = (None, model_dim)\n",
        "layout_map[\"decoder_block.*ffw_linear.*kernel\"] = (model_dim, None)"
      ]
    },
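    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The sketch below illustrates how regex-style keys select tensors. It reuses the patterns from the cell above and tests them against a few decoder variable paths with Python's `re.search`. This is an approximation for illustration: `lookup_layout` is a hypothetical helper, and the exact matching rules that `LayoutMap` applies internally may differ between Keras versions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import re\n",
        "\n",
        "# Same patterns and layouts as the LayoutMap above, kept in a plain dict.\n",
        "patterns = {\n",
        "    'token_embedding/embeddings': ('model', None),\n",
        "    'decoder_block.*attention.*(query|key|value).*kernel': (\n",
        "        'model', None, None),\n",
        "    'decoder_block.*attention_output.*kernel': ('model', None, None),\n",
        "    'decoder_block.*ffw_gating.*kernel': (None, 'model'),\n",
        "    'decoder_block.*ffw_linear.*kernel': ('model', None),\n",
        "}\n",
        "\n",
        "def lookup_layout(path):\n",
        "  \"\"\"Returns the layout of the first pattern found in `path`, else None.\"\"\"\n",
        "  for pattern, layout in patterns.items():\n",
        "    if re.search(pattern, path):\n",
        "      return layout\n",
        "  return None  # No match: the tensor would be fully replicated.\n",
        "\n",
        "for path in [\n",
        "    'decoder_block_1/pre_attention_norm/scale',\n",
        "    'decoder_block_1/attention/query/kernel',\n",
        "    'decoder_block_1/attention/attention_output/kernel',\n",
        "    'decoder_block_1/ffw_gating_2/kernel',\n",
        "    'decoder_block_1/ffw_linear/kernel',\n",
        "]:\n",
        "  print(f'{path:<52} {lookup_layout(path)}')"
      ]
    },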
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6n4Zlvk9ALhZ"
      },
      "source": [
        "`ModelParallel` allows you to shard model weights or activation tensors across all devices on the `DeviceMesh`. In this case, some of the Gemma 7B model weights are sharded across the 8 TPU cores according to the `layout_map` defined above. Now load the model in the distributed way."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bu48vUnbQj0p"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...\n",
            "Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...\n",
            "Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...\n",
            "Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...\n",
            "Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...\n",
            "normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.\n"
          ]
        }
      ],
      "source": [
        "model_parallel = keras.distribution.ModelParallel(\n",
        "    layout_map=layout_map, batch_dim_name=\"batch\")\n",
        "\n",
        "keras.distribution.set_distribution(model_parallel)\n",
        "gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(\"gemma_7b_en\")\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ORCOIawAvpZ1"
      },
      "source": [
        "Now verify that the model has been partitioned correctly. Let's take `decoder_block_1` as an example."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DqT7TRHKvoMK"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>\n",
            "decoder_block_1/pre_attention_norm/scale                    (3072,)           PartitionSpec(None,)\n",
            "decoder_block_1/attention/query/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)\n",
            "decoder_block_1/attention/key/kernel                        (16, 3072, 256)   PartitionSpec(None, 'model', None)\n",
            "decoder_block_1/attention/value/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)\n",
            "decoder_block_1/attention/attention_output/kernel           (16, 256, 3072)   PartitionSpec(None, None, 'model')\n",
            "decoder_block_1/pre_ffw_norm/scale                          (3072,)           PartitionSpec(None,)\n",
            "decoder_block_1/ffw_gating/kernel                           (3072, 24576)     PartitionSpec('model', None)\n",
            "decoder_block_1/ffw_gating_2/kernel                         (3072, 24576)     PartitionSpec('model', None)\n",
            "decoder_block_1/ffw_linear/kernel                           (24576, 3072)     PartitionSpec(None, 'model')\n"
          ]
        }
      ],
      "source": [
        "decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')\n",
        "print(type(decoder_block_1))\n",
        "for variable in decoder_block_1.weights:\n",
        "  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jc0ZzYIW0TSN"
      },
      "source": [
        "## Inference before finetuning"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ClaTyBp3Tgr4"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'"
            ]
          },
          "execution_count": 13,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "gemma_lm.generate(\"Best comedy movies in the 90s \", max_length=64)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tKUXYLW_0lAx"
      },
      "source": [
        "The model generates a list of great comedy movies from the 90s to watch. Now we finetune the Gemma model to change the output style."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IcPCXCwvXC7t"
      },
      "source": [
        "## Finetune with IMDB"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6MVJlsuSXCcf"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\u001b[1mDownloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...\u001b[0m\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "eec649e96f174bfaa22c0598addc0e3a",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Dl Completed...: 0 url [00:00, ? url/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "8c6f4e3c0b7d47589671b0759263bab5",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Dl Size...: 0 MiB [00:00, ? MiB/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "974fa4e362684fe0ba8df0ab2d75b3c4",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "ef99c3a33d85433e986d715203917814",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Generating train examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "b3962688721343bd8440dad1f1176296",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord…"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "abaa54e825ab4ac7a1ee78345039ca4c",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Generating test examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "650324b2c42d488396bd2abec6baf636",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*…"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "4c5e107ec0a8466d8370bb8b03df1e0a",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Generating unsupervised examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "7efcadb1ce53445697f34e263d41e623",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t…"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\u001b[1mDataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.\u001b[0m\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "b\"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it.\""
            ]
          },
          "execution_count": 14,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "import tensorflow_datasets as tfds\n",
        "\n",
        "imdb_train = tfds.load(\n",
        "    \"imdb_reviews\",\n",
        "    split=\"train\",\n",
        "    as_supervised=True,\n",
        "    batch_size=2,\n",
        ")\n",
        "# Drop labels.\n",
        "imdb_train = imdb_train.map(lambda x, y: x)\n",
        "\n",
        "imdb_train.unbatch().take(1).get_single_element().numpy()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "N-04czr3XTTl"
      },
      "outputs": [],
      "source": [
        "# Use a subset of the dataset for faster training.\n",
        "imdb_train = imdb_train.take(2000)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xiLW0SpI1PfC"
      },
      "source": [
        "Perform finetuning using [Low Rank Adaptation](https://arxiv.org/abs/2106.09685) (LoRA). LoRA is a fine-tuning technique that greatly reduces the number of trainable parameters for downstream tasks by freezing the full weights of the model and inserting a smaller number of new trainable weights into the model. In essence, LoRA keeps each large weight matrix frozen and learns an update to it expressed as the product of two much smaller low-rank matrices, A x B, which makes training much faster and more memory-efficient."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3o_Gi3v_jp7s"
      },
      "outputs": [],
      "source": [
        "# Enable LoRA for the model and set the LoRA rank to 4.\n",
        "gemma_lm.backbone.enable_lora(rank=4)"
      ]
    },
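    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To get a feel for the savings, the sketch below counts trainable parameters for a single 2D weight matrix with and without a rank-4 low-rank update. The `3072 x 4096` shape is only an assumed example, and the count follows the generic formulation from the LoRA paper; it does not reproduce how Keras lays out its LoRA weights internally."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Illustrative LoRA parameter count for one weight matrix (assumed shape).\n",
        "def lora_param_count(in_dim, out_dim, rank):\n",
        "  \"\"\"Trainable parameters of A (in_dim x rank) plus B (rank x out_dim).\"\"\"\n",
        "  return in_dim * rank + rank * out_dim\n",
        "\n",
        "in_dim, out_dim, rank = 3072, 4096, 4\n",
        "full = in_dim * out_dim\n",
        "lora = lora_param_count(in_dim, out_dim, rank)\n",
        "\n",
        "print(f'Full matrix parameters: {full:,}')\n",
        "print(f'LoRA parameters (rank {rank}): {lora:,}')\n",
        "print(f'Trainable fraction: {lora / full:.4%}')"
      ]
    },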
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1-hQFy7hXWRl"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Preprocessor: \"gemma_causal_lm_preprocessor\"</span>\n",
              "</pre>\n"
            ],
            "text/plain": [
              "\u001b[1mPreprocessor: \"gemma_causal_lm_preprocessor\"\u001b[0m\n"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
              "┃<span style=\"font-weight: bold\"> Tokenizer (type)                                   </span>┃<span style=\"font-weight: bold\">                                             Vocab # </span>┃\n",
              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
              "│ gemma_tokenizer (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">GemmaTokenizer</span>)                   │                                             <span style=\"color: #00af00; text-decoration-color: #00af00\">256,000</span> │\n",
              "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n",
              "</pre>\n"
            ],
            "text/plain": [
              "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
              "┃\u001b[1m \u001b[0m\u001b[1mTokenizer (type)                                  \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m                                            Vocab #\u001b[0m\u001b[1m \u001b[0m┃\n",
              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
              "│ gemma_tokenizer (\u001b[38;5;33mGemmaTokenizer\u001b[0m)                   │                                             \u001b[38;5;34m256,000\u001b[0m │\n",
              "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"gemma_causal_lm\"</span>\n",
              "</pre>\n"
            ],
            "text/plain": [
              "\u001b[1mModel: \"gemma_causal_lm\"\u001b[0m\n"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
              "┃<span style=\"font-weight: bold\"> Layer (type)                  </span>┃<span style=\"font-weight: bold\"> Output Shape              </span>┃<span style=\"font-weight: bold\">         Param # </span>┃<span style=\"font-weight: bold\"> Connected to               </span>┃\n",
              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
              "│ padding_mask (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>)     │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>)              │               <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ -                          │\n",
              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
              "│ token_ids (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>)        │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>)              │               <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ -                          │\n",
              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
              "│ gemma_backbone                │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">3072</span>)        │   <span style=\"color: #00af00; text-decoration-color: #00af00\">8,548,748,288</span> │ padding_mask[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>],        │\n",
              "│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">GemmaBackbone</span>)               │                           │                 │ token_ids[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]            │\n",
              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
              "│ token_embedding               │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256000</span>)      │     <span style=\"color: #00af00; text-decoration-color: #00af00\">786,432,000</span> │ gemma_backbone[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>]       │\n",
              "│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">ReversibleEmbedding</span>)         │                           │                 │                            │\n",
              "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n",
              "</pre>\n"
            ],
            "text/plain": [
              "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
              "┃\u001b[1m \u001b[0m\u001b[1mLayer (type)                 \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape             \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m        Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to              \u001b[0m\u001b[1m \u001b[0m┃\n",
              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
              "│ padding_mask (\u001b[38;5;33mInputLayer\u001b[0m)     │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m)              │               \u001b[38;5;34m0\u001b[0m │ -                          │\n",
              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
              "│ token_ids (\u001b[38;5;33mInputLayer\u001b[0m)        │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m)              │               \u001b[38;5;34m0\u001b[0m │ -                          │\n",
              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
              "│ gemma_backbone                │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3072\u001b[0m)        │   \u001b[38;5;34m8,548,748,288\u001b[0m │ padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m],        │\n",
              "│ (\u001b[38;5;33mGemmaBackbone\u001b[0m)               │                           │                 │ token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]            │\n",
              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
              "│ token_embedding               │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256000\u001b[0m)      │     \u001b[38;5;34m786,432,000\u001b[0m │ gemma_backbone[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m]       │\n",
              "│ (\u001b[38;5;33mReversibleEmbedding\u001b[0m)         │                           │                 │                            │\n",
              "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">8,548,748,288</span> (31.85 GB)\n",
              "</pre>\n"
            ],
            "text/plain": [
              "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m8,548,748,288\u001b[0m (31.85 GB)\n"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">11,067,392</span> (42.22 MB)\n",
              "</pre>\n"
            ],
            "text/plain": [
              "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m11,067,392\u001b[0m (42.22 MB)\n"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">8,537,680,896</span> (31.81 GB)\n",
              "</pre>\n"
            ],
            "text/plain": [
              "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m8,537,680,896\u001b[0m (31.81 GB)\n"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), 
ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), 
ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), 
ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]).\n",
            "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.\n",
            "  warnings.warn(\"Some donated buffers were not usable:\"\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\u001b[1m2000/2000\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m358s\u001b[0m 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "<keras.src.callbacks.history.History at 0x7e9cac7f41c0>"
            ]
          },
          "execution_count": 17,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Fine-tune on the IMDb movie reviews dataset.\n",
        "\n",
        "# Limit the input sequence length to 128 to control memory usage.\n",
        "gemma_lm.preprocessor.sequence_length = 128\n",
        "# Use AdamW (a common optimizer for transformer models).\n",
        "optimizer = keras.optimizers.AdamW(\n",
        "    learning_rate=5e-5,\n",
        "    weight_decay=0.01,\n",
        ")\n",
        "# Exclude layernorm and bias terms from decay.\n",
        "optimizer.exclude_from_weight_decay(var_names=[\"bias\", \"scale\"])\n",
        "\n",
        "gemma_lm.compile(\n",
        "    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "    optimizer=optimizer,\n",
        "    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],\n",
        ")\n",
        "gemma_lm.summary()\n",
        "gemma_lm.fit(imdb_train, epochs=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CnpeavB4fZ7Y"
      },
      "source": [
        "Note that enabling LoRA reduces the number of trainable parameters significantly: the summary above reports about 8.5 billion total parameters for the 7B model, of which only about 11 million are trainable."
      ]
    },
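    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough sanity check on the summary output above, the figures can be worked through by hand. This is only a sketch: it uses the counts exactly as printed by `gemma_lm.summary()`, and it assumes every parameter is stored as `float32` (4 bytes), which is consistent with the `float32` arrays named in the warning but is not stated elsewhere in this tutorial.\n",
        "\n",
        "$$\n",
        "\\frac{\\text{trainable}}{\\text{total}} = \\frac{11{,}067{,}392}{8{,}548{,}748{,}288} \\approx 1.3 \\times 10^{-3} \\approx 0.13\\%\n",
        "$$\n",
        "\n",
        "$$\n",
        "\\text{non-trainable} = 8{,}548{,}748{,}288 - 11{,}067{,}392 = 8{,}537{,}680{,}896\n",
        "$$\n",
        "\n",
        "$$\n",
        "11{,}067{,}392 \\times 4\\ \\text{bytes} = 44{,}269{,}568\\ \\text{bytes} \\approx 42.22 \\times 2^{20}\\ \\text{bytes}\n",
        "$$\n",
        "\n",
        "$$\n",
        "8{,}548{,}748{,}288 \\times 4\\ \\text{bytes} = 34{,}194{,}993{,}152\\ \\text{bytes} \\approx 31.85 \\times 2^{30}\\ \\text{bytes}\n",
        "$$\n",
        "\n",
        "Under that assumption, the sizes labeled MB and GB in the summary match binary units (powers of 1024), and roughly 0.13% of the weights are updated during finetuning."
      ]
    },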
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lBiOKlAy2MAe"
      },
      "source": [
        "## Inference after finetuning"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9yNyJ8CLXfw0"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "\"Best comedy movies in the 90s \\n\\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great.\""
            ]
          },
          "execution_count": 32,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "gemma_lm.generate(\"Best comedy movies in the 90s \", max_length=64)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "inqB1e_v0xP5"
      },
      "source": [
        "After finetuning, the model has learned the style of movie reviews: given the prompt about 90s comedy movies, it now responds with text that reads like a review."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bzKsCGIN0yX5"
      },
      "source": [
        "## What's next\n",
        "\n",
        "In this tutorial, you learned how to use KerasNLP with the JAX backend to finetune a Gemma model on the IMDb dataset in a distributed manner on TPUs. Here are a few suggestions for what else to learn:\n",
        "\n",
        "* Learn how to [get started with Keras Gemma](https://ai.google.dev/gemma/docs/get_started).\n",
        "* Learn how to [finetune the Gemma model on GPU](https://ai.google.dev/gemma/docs/core/lora_tuning)."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "distributed_tuning.ipynb",
      "toc_visible": true
    },
    "google": {
      "image_path": "/site-assets/images/marketing/gemma.png",
      "keywords": [
        "examples",
        "gemma",
        "python",
        "quickstart",
        "text"
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
